import os
import pickle
import time

import numpy as np
import networkx as nx
from sklearn.cluster import SpectralClustering

import cdlib
from cdlib.algorithms import girvan_newman, belief, louvain

from geometric_clustering import compute_curvatures, cluster_signed_modularity

def store_result(res, key, index, value):
    """Store ``value`` at position ``index`` of the list ``res[key]``.

    The list is created if missing and padded with ``None`` up to ``index``,
    so entry ``index`` always belongs to ``res['graphs'][index]``. Re-running
    the benchmark therefore overwrites earlier results instead of appending
    duplicates after them (which would shift every index).
    """
    results = res.setdefault(key, [])
    if len(results) <= index:
        results.extend([None] * (index + 1 - len(results)))
    results[index] = value


def group_by_label(nodes, labels, n_clusters):
    """Split ``nodes`` into ``n_clusters`` communities according to ``labels``.

    ``labels[k]`` is the cluster label of the k-th element of ``nodes``; nodes
    and labels are paired by position, never by using a node id as an index.
    Community ``c`` lists the nodes labelled ``c`` (possibly empty), for
    ``c`` in ``range(n_clusters)``.
    """
    node_labels = list(zip(nodes, labels))
    return [[node for node, label in node_labels if label == c]
            for c in range(n_clusters)]


def save_results(res, path):
    """Pickle ``res`` to ``path`` without ever leaving a half-written file.

    The benchmark pickle also holds the input graphs, so we write to a
    sibling temporary file first and atomically rename it over ``path``;
    an interruption mid-write leaves the previous checkpoint intact.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(res, f)
    os.replace(tmp_path, path)


def main(path='SBM_benchmark.pkl', n_workers=14):
    """Run every clustering method on each benchmark graph, checkpointing results.

    Loads the dict pickled at ``path`` (it must contain a ``'graphs'`` list),
    then for each graph: computes curvatures with ``compute_curvatures``, runs
    ``cluster_signed_modularity``, Louvain, spectral clustering (2 clusters),
    Girvan-Newman (level 1) and belief propagation. Each result is stored at
    the graph's index under its own key, and the whole dict is re-saved after
    every graph so progress survives interruption.

    NOTE: ``pickle.load`` can execute arbitrary code; only point ``path`` at
    a benchmark file you produced yourself.
    """
    with open(path, 'rb') as f:
        res = pickle.load(f)

    # Curvature time grid and measure cutoff are identical for every graph.
    times = np.logspace(-.5, 0, 25)
    epsilon = np.finfo(float).eps
    n_clusters = 2

    for i, G in enumerate(res['graphs']):
        print(i)

        # compute curvature
        print('Computing curvatures')
        now = time.time()
        kappas = compute_curvatures(G, times, n_workers=n_workers,
                                    use_spectral_gap=True,
                                    measure_cutoff=epsilon)
        print(time.time() - now)

        # geometric modularity clustering
        print('Running geometric modularity')
        geometric_modularity = cluster_signed_modularity(
            G,
            times,
            kappas,
            kappa0=0.0,
            n_louvain=50,
            n_louvain_VI=50,
            with_postprocessing=True,
            n_workers=n_workers,
        )

        # Louvain modularity maximisation
        print('Running modularity')
        mod = louvain(G)

        # spectral clustering on the precomputed adjacency matrix; labels_ is
        # ordered like G.nodes, the row order of nx.adjacency_matrix(G)
        print('Running spectral clustering')
        clusters = SpectralClustering(
            n_clusters=n_clusters,
            assign_labels="discretize",
            affinity='precomputed',
            random_state=0,
        ).fit(nx.adjacency_matrix(G))
        # NOTE(review): method_name is "reference" although this is the
        # spectral result; kept as-is in case downstream analysis keys on it.
        spectral = cdlib.NodeClustering(
            communities=group_by_label(G.nodes, clusters.labels_, n_clusters),
            graph=G,
            method_name="reference",
        )

        # Girvan-Newman
        print('Running Girvan-Newman')
        GN = girvan_newman(G, level=1)

        # belief propagation
        print('Running belief propagation')
        now = time.time()
        belief_propagation = belief(G)
        print(time.time() - now)

        # collect results, index-aligned with res['graphs']
        results = {
            'curvatures': kappas,
            'geometric_modularity': geometric_modularity,
            'modularity': mod,
            'spectral': spectral,
            'GN': GN,
            'belief_propagation': belief_propagation,
        }
        for key, value in results.items():
            store_result(res, key, i, value)

        # checkpoint after every graph
        save_results(res, path)


if __name__ == "__main__":
    main()
    